{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "b518b04cbfe0"
      },
      "source": [
        "##### Copyright 2020 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "906e07f6e562"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a81c428fc2d3"
      },
      "source": [
        "# 迁移学习和微调"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3e5a59f0aefd"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>     <a target=\"_blank\" href=\"https://tensorflow.google.cn/guide/keras/transfer_learning\"><img src=\"https://tensorflow.google.cn/images/tf_logo_32px.png\">在 TensorFlow.org 上查看</a>\n",
        "</td>\n",
        "  <td>     <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/zh-cn/guide/keras/transfer_learning.ipynb\"><img src=\"https://tensorflow.google.cn/images/colab_logo_32px.png\">在 Google Colab 中运行</a>\n",
        "</td>\n",
        "  <td>     <a target=\"_blank\" href=\"https://github.com/tensorflow/docs-l10n/blob/master/site/zh-cn/guide/keras/transfer_learning.ipynb\"><img src=\"https://tensorflow.google.cn/images/GitHub-Mark-32px.png\">在 GitHub 上查看源代码</a>\n",
        "</td>\n",
        "  <td>     <a href=\"https://storage.googleapis.com/tensorflow_docs/docs-l10n/site/zh-cn/guide/keras/transfer_learning.ipynb\"><img src=\"https://tensorflow.google.cn/images/download_logo_32px.png\">下载笔记本</a>\n",
        "</td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8d4ac441b1fc"
      },
      "source": [
        "## 设置"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9a7e9b92f963"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "from tensorflow import keras"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "00d4c41cfe2f"
      },
      "source": [
        "## 简介\n",
        "\n",
        "**迁移学习**包括获取从一个问题中学习到的特征，然后将这些特征用于新的类似问题。例如，来自已学会识别浣熊的模型的特征可能对建立旨在识别狸猫的模型十分有用。\n",
        "\n",
        "对于数据集中的数据太少而无法从头开始训练完整模型的任务，通常会执行迁移学习。\n",
        "\n",
        "在深度学习情境中，迁移学习最常见的形式是以下工作流：\n",
        "\n",
        "1. 从之前训练的模型中获取层。\n",
        "2. 冻结这些层，以避免在后续训练轮次中破坏它们包含的任何信息。\n",
        "3. 在已冻结层的顶部添加一些新的可训练层。这些层会学习将旧特征转换为对新数据集的预测。\n",
        "4. 在您的数据集上训练新层。\n",
        "\n",
        "最后一个可选步骤是**微调**，包括解冻上面获得的整个模型（或模型的一部分），然后在新数据上以极低的学习率对该模型进行重新训练。以增量方式使预训练特征适应新数据，有可能实现有意义的改进。\n",
        "\n",
        "首先，我们将详细介绍 Keras `trainable` API，它是大多数迁移学习和微调工作流的基础。\n",
        "\n",
        "随后，我们将演示一个典型工作流：先获得一个在 ImageNet 数据集上预训练的模型，然后在 Kaggle Dogs vs. Cats 分类数据集上对该模型进行重新训练。\n",
        "\n",
        "此工作流改编自[《Python 深度学习》](https://www.manning.com/books/deep-learning-with-python)和 2016 年的博文[“使用极少的数据构建强大的图像分类模型”](https://blog.keras.io/building-powerful-image-classification-models-using-very-little-data.html)。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fbf8630c325b"
      },
      "source": [
        "## 冻结层：了解 `trainable` 特性\n",
        "\n",
        "层和模型具有三个权重特性：\n",
        "\n",
        "- `weights` 是层的所有权重变量的列表。\n",
        "- `trainable_weights` 是需要进行更新（通过梯度下降）以尽可能减少训练过程中损失的权重列表。\n",
        "- `non_trainable_weights` 是不适合训练的权重列表。它们通常在正向传递过程中由模型更新。\n",
        "\n",
        "**示例：`Dense` 层具有 2 个可训练权重（内核与偏差）**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "407deab1855e"
      },
      "outputs": [],
      "source": [
        "layer = keras.layers.Dense(3)\n",
        "layer.build((None, 4))  # Create the weights\n",
        "\n",
        "print(\"weights:\", len(layer.weights))\n",
        "print(\"trainable_weights:\", len(layer.trainable_weights))\n",
        "print(\"non_trainable_weights:\", len(layer.non_trainable_weights))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "79fcb9cc960d"
      },
      "source": [
        "一般而言，所有权重都是可训练权重。唯一具有不可训练权重的内置层是 `BatchNormalization` 层。在训练期间，它使用不可训练权重跟踪其输入的平均值和方差。要了解如何在您自己的自定义层中使用不可训练权重，请参阅[从头开始编写新层的指南](https://keras.io/guides/making_new_layers_and_models_via_subclassing/)。\n",
        "\n",
        "**示例：`BatchNormalization` 层具有 2 个可训练权重和 2 个不可训练权重**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fbc87a09bc3c"
      },
      "outputs": [],
      "source": [
        "layer = keras.layers.BatchNormalization()\n",
        "layer.build((None, 4))  # Create the weights\n",
        "\n",
        "print(\"weights:\", len(layer.weights))\n",
        "print(\"trainable_weights:\", len(layer.trainable_weights))\n",
        "print(\"non_trainable_weights:\", len(layer.non_trainable_weights))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cddcdbf2bd5b"
      },
      "source": [
        "层和模型还具有布尔特性 `trainable`。此特性的值可以更改。将 `layer.trainable` 设置为 `False` 会将层的所有权重从可训练移至不可训练。这一过程称为“冻结”层：已冻结层的状态在训练期间不会更新（无论是使用 `fit()` 进行训练，还是使用依赖于 `trainable_weights` 来应用梯度更新的任何自定义循环进行训练时）。\n",
        "\n",
        "**示例：将 `trainable` 设置为 `False`**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "51bbc5d12742"
      },
      "outputs": [],
      "source": [
        "layer = keras.layers.Dense(3)\n",
        "layer.build((None, 4))  # Create the weights\n",
        "layer.trainable = False  # Freeze the layer\n",
        "\n",
        "print(\"weights:\", len(layer.weights))\n",
        "print(\"trainable_weights:\", len(layer.trainable_weights))\n",
        "print(\"non_trainable_weights:\", len(layer.non_trainable_weights))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "32904f9a58db"
      },
      "source": [
        "当可训练权重变为不可训练时，它的值在训练期间不再更新。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3c26c27a8291"
      },
      "outputs": [],
      "source": [
        "# Make a model with 2 layers\n",
        "layer1 = keras.layers.Dense(3, activation=\"relu\")\n",
        "layer2 = keras.layers.Dense(3, activation=\"sigmoid\")\n",
        "model = keras.Sequential([keras.Input(shape=(3,)), layer1, layer2])\n",
        "\n",
        "# Freeze the first layer\n",
        "layer1.trainable = False\n",
        "\n",
        "# Keep a copy of the weights of layer1 for later reference\n",
        "initial_layer1_weights_values = layer1.get_weights()\n",
        "\n",
        "# Train the model\n",
        "model.compile(optimizer=\"adam\", loss=\"mse\")\n",
        "model.fit(np.random.random((2, 3)), np.random.random((2, 3)))\n",
        "\n",
        "# Check that the weights of layer1 have not changed during training\n",
        "final_layer1_weights_values = layer1.get_weights()\n",
        "np.testing.assert_allclose(\n",
        "    initial_layer1_weights_values[0], final_layer1_weights_values[0]\n",
        ")\n",
        "np.testing.assert_allclose(\n",
        "    initial_layer1_weights_values[1], final_layer1_weights_values[1]\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "412d7d659aa1"
      },
      "source": [
        "请勿将 `layer.trainable` 特性与 `layer.__call__()` 中的 `training` 参数（此参数控制层是在推断模式还是训练模式下运行其前向传递）混淆。有关更多信息，请参阅 [Keras 常见问题解答](https://keras.io/getting_started/faq/#whats-the-difference-between-the-training-argument-in-call-and-the-trainable-attribute)。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e6ccd3c7ab1a"
      },
      "source": [
        "## `trainable` 特性的递归设置\n",
        "\n",
        "如果在模型或具有子层的任何层上设置 `trainable = False`，则所有子层也将变为不可训练。\n",
        "\n",
        "**示例**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4235d0c69821"
      },
      "outputs": [],
      "source": [
        "inner_model = keras.Sequential(\n",
        "    [\n",
        "        keras.Input(shape=(3,)),\n",
        "        keras.layers.Dense(3, activation=\"relu\"),\n",
        "        keras.layers.Dense(3, activation=\"relu\"),\n",
        "    ]\n",
        ")\n",
        "\n",
        "model = keras.Sequential(\n",
        "    [keras.Input(shape=(3,)), inner_model, keras.layers.Dense(3, activation=\"sigmoid\"),]\n",
        ")\n",
        "\n",
        "model.trainable = False  # Freeze the outer model\n",
        "\n",
        "assert inner_model.trainable == False  # All layers in `model` are now frozen\n",
        "assert inner_model.layers[0].trainable == False  # `trainable` is propagated recursively"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "61535ba76727"
      },
      "source": [
        "## 典型的迁移学习工作流\n",
        "\n",
        "下面将介绍如何在 Keras 中实现典型的迁移学习工作流：\n",
        "\n",
        "1. 实例化一个基础模型并加载预训练权重。\n",
        "2. 通过设置 `trainable = False` 冻结基础模型中的所有层。\n",
        "3. 根据基础模型中一个（或多个）层的输出创建一个新模型。\n",
        "4. 在您的新数据集上训练新模型。\n",
        "\n",
        "请注意，另一种更轻量的工作流如下：\n",
        "\n",
        "1. 实例化一个基础模型并加载预训练权重。\n",
        "2. 通过该模型运行新的数据集，并记录基础模型中一个（或多个）层的输出。这一过程称为**特征提取**。\n",
        "3. 使用该输出作为新的较小模型的输入数据。\n",
        "\n",
        "第二种工作流有一个关键优势，即您只需在自己的数据上运行一次基础模型，而不是每个训练周期都运行一次。因此，它的速度更快，开销也更低。\n",
        "\n",
        "但是，第二种工作流存在一个问题，即它不允许您在训练期间动态修改新模型的输入数据，在进行数据扩充时，这种修改必不可少。当新数据集的数据太少而无法从头开始训练完整模型时，任务通常会使用迁移学习，在这种情况下，数据扩充非常重要。因此，在接下来的篇幅中，我们将专注于第一种工作流。\n",
        "\n",
        "下面是 Keras 中第一种工作流的样子：\n",
        "\n",
        "首先，实例化一个具有预训练权重的基础模型。\n",
        "\n",
        "```python\n",
        "base_model = keras.applications.Xception(\n",
        "    weights='imagenet',  # Load weights pre-trained on ImageNet.\n",
        "    input_shape=(150, 150, 3),\n",
        "    include_top=False)  # Do not include the ImageNet classifier at the top.\n",
        "```\n",
        "\n",
        "随后，冻结该基础模型。\n",
        "\n",
        "```python\n",
        "base_model.trainable = False\n",
        "```\n",
        "\n",
        "根据基础模型创建一个新模型。\n",
        "\n",
        "```python\n",
        "inputs = keras.Input(shape=(150, 150, 3))\n",
        "# We make sure that the base_model is running in inference mode here,\n",
        "# by passing `training=False`. This is important for fine-tuning, as you will\n",
        "# learn in a few paragraphs.\n",
        "x = base_model(inputs, training=False)\n",
        "# Convert features of shape `base_model.output_shape[1:]` to vectors\n",
        "x = keras.layers.GlobalAveragePooling2D()(x)\n",
        "# A Dense classifier with a single unit (binary classification)\n",
        "outputs = keras.layers.Dense(1)(x)\n",
        "model = keras.Model(inputs, outputs)\n",
        "```\n",
        "\n",
        "在新数据上训练该模型。\n",
        "\n",
        "```python\n",
        "model.compile(optimizer=keras.optimizers.Adam(),\n",
        "              loss=keras.losses.BinaryCrossentropy(from_logits=True),\n",
        "              metrics=[keras.metrics.BinaryAccuracy()])\n",
        "model.fit(new_dataset, epochs=20, callbacks=..., validation_data=...)\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "736c99aea690"
      },
      "source": [
        "## 微调\n",
        "\n",
        "一旦模型在新数据上收敛，您就可以尝试解冻全部或部分基础模型，并以极低的学习率端到端地重新训练整个模型。\n",
        "\n",
        "这是可选的最后一个步骤，可能给您带来增量式改进。不过，它也可能导致快速过拟合，请牢记这一点。\n",
        "\n",
        "重要的是，只有在将具有冻结层的模型训练至收敛*后*，才能执行此步骤。如果将随机初始化的可训练层与包含预训练特征的可训练层混合使用，则随机初始化的层将在训练过程中引起非常大的梯度更新，而这将破坏您的预训练特征。\n",
        "\n",
        "在此阶段使用极低的学习率也很重要，因为与第一轮训练相比，您正在一个通常非常小的数据集上训练一个大得多的模型。因此，如果您应用较大的权重更新，则存在很快过拟合的风险。在这里，您只需要以增量方式重新调整预训练权重。\n",
        "\n",
        "下面是实现整个基础模型微调的方法：\n",
        "\n",
        "```python\n",
        "# Unfreeze the base model\n",
        "base_model.trainable = True\n",
        "\n",
        "# It's important to recompile your model after you make any changes\n",
        "# to the `trainable` attribute of any inner layer, so that your changes\n",
        "# are take into account\n",
        "model.compile(optimizer=keras.optimizers.Adam(1e-5),  # Very low learning rate\n",
        "              loss=keras.losses.BinaryCrossentropy(from_logits=True),\n",
        "              metrics=[keras.metrics.BinaryAccuracy()])\n",
        "\n",
        "# Train end-to-end. Be careful to stop before you overfit!\n",
        "model.fit(new_dataset, epochs=10, callbacks=..., validation_data=...)\n",
        "```\n",
        "\n",
        "**关于 `compile()` 和 `trainable` 的重要说明**\n",
        "\n",
        "在模型上调用 `compile()` 是为了“冻结”该模型的行为。这意味着在编译模型时，应当在该模型的整个生命周期中保留 `trainable` 特性值，直到再次调用 `compile` 为止。因此，如果您更改了任何 `trainable` 值，请确保在模型上再次调用 `compile()`，将这些变更考虑在内。\n",
        "\n",
        "**关于 `BatchNormalization` 层的重要说明**\n",
        "\n",
        "许多图像模型都包含 `BatchNormalization` 层。在各种能想到的数量上，该层都是一个特例。需要牢记以下几点。\n",
        "\n",
        "- `BatchNormalization` 包含 2 个会在训练过程中更新的不可训练权重。它们是跟踪输入的平均值和方差的变量。\n",
        "- 设置 `bn_layer.trainable = False` 时，`BatchNormalization` 层将以推断模式运行，并且不会更新其均值和方差统计信息。其他层一般不是这种情况，因为[权重可训练性和推断/训练模式是两个正交的概念](https://keras.io/getting_started/faq/#whats-the-difference-between-the-training-argument-in-call-and-the-trainable-attribute)。但是对于 `BatchNormalization` 层，两者是绑定的。\n",
        "- 解冻包含 `BatchNormalization` 层的模型以进行微调时，应在调用基础模型时通过传递 `training=False` 来使 `BatchNormalization` 层保持在推断模式下。否则，应用于不可训练权重的更新将突然破坏模型学习到的内容。\n",
        "\n",
        "您将在本指南结尾处的端到端示例中看到这种模式的实际运行。\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bce9ffc4e290"
      },
      "source": [
        "## 使用自定义训练循环进行迁移学习和微调\n",
        "\n",
        "如果您使用的是自己的低级训练循环，而不是 `fit()`，则工作流基本保持不变。在应用梯度更新时，您应当注意仅考虑列表 `model.trainable_weights`：\n",
        "\n",
        "```python\n",
        "# Create base model\n",
        "base_model = keras.applications.Xception(\n",
        "    weights='imagenet',\n",
        "    input_shape=(150, 150, 3),\n",
        "    include_top=False)\n",
        "# Freeze base model\n",
        "base_model.trainable = False\n",
        "\n",
        "# Create new model on top.\n",
        "inputs = keras.Input(shape=(150, 150, 3))\n",
        "x = base_model(inputs, training=False)\n",
        "x = keras.layers.GlobalAveragePooling2D()(x)\n",
        "outputs = keras.layers.Dense(1)(x)\n",
        "model = keras.Model(inputs, outputs)\n",
        "\n",
        "loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)\n",
        "optimizer = keras.optimizers.Adam()\n",
        "\n",
        "# Iterate over the batches of a dataset.\n",
        "for inputs, targets in new_dataset:\n",
        "    # Open a GradientTape.\n",
        "    with tf.GradientTape() as tape:\n",
        "        # Forward pass.\n",
        "        predictions = model(inputs)\n",
        "        # Compute the loss value for this batch.\n",
        "        loss_value = loss_fn(targets, predictions)\n",
        "\n",
        "    # Get gradients of loss wrt the *trainable* weights.\n",
        "    gradients = tape.gradient(loss_value, model.trainable_weights)\n",
        "    # Update the weights of the model.\n",
        "    optimizer.apply_gradients(zip(gradients, model.trainable_weights))\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4e63ba34ce1c"
      },
      "source": [
        "对于微调同样如此。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "852447087ba9"
      },
      "source": [
        "## 端到端示例：基于 Dogs vs. Cats 数据集微调图像分类模型\n",
        "\n",
        "为了巩固这些概念，我们先介绍一个具体的端到端迁移学习和微调示例。我们将加载在 ImageNet 上预训练的 Xception 模型，并将其用于 Kaggle Dogs vs. Cats 分类数据集。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ba75835e0de6"
      },
      "source": [
        "### 获取数据\n",
        "\n",
        "首先，我们使用 TFDS 来获取 Dogs vs. Cats 数据集。如果您拥有自己的数据集，则可能需要使用效用函数 `tf.keras.preprocessing.image_dataset_from_directory` 从磁盘上存档到类特定的文件夹中的一组图像来生成相似的有标签数据集对象。\n",
        "\n",
        "使用非常小的数据集时，迁移学习最实用。为了使数据集保持较小状态，我们将原始训练数据（25,000 个图像）的 40% 用于训练，10% 用于验证，10% 用于测试。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1a99f56934f7"
      },
      "outputs": [],
      "source": [
        "import tensorflow_datasets as tfds\n",
        "\n",
        "tfds.disable_progress_bar()\n",
        "\n",
        "train_ds, validation_ds, test_ds = tfds.load(\n",
        "    \"cats_vs_dogs\",\n",
        "    # Reserve 10% for validation and 10% for test\n",
        "    split=[\"train[:40%]\", \"train[40%:50%]\", \"train[50%:60%]\"],\n",
        "    as_supervised=True,  # Include labels\n",
        ")\n",
        "\n",
        "print(\"Number of training samples: %d\" % tf.data.experimental.cardinality(train_ds))\n",
        "print(\n",
        "    \"Number of validation samples: %d\" % tf.data.experimental.cardinality(validation_ds)\n",
        ")\n",
        "print(\"Number of test samples: %d\" % tf.data.experimental.cardinality(test_ds))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9db548603642"
      },
      "source": [
        "下面是训练数据集中的前 9 个图像。如您所见，它们具有不同的大小。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "00c8cbd1de88"
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n",
        "\n",
        "plt.figure(figsize=(10, 10))\n",
        "for i, (image, label) in enumerate(train_ds.take(9)):\n",
        "    ax = plt.subplot(3, 3, i + 1)\n",
        "    plt.imshow(image)\n",
        "    plt.title(int(label))\n",
        "    plt.axis(\"off\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "168c4a10c072"
      },
      "source": [
        "我们还可以看到标签 1 是“狗”，标签 0 是“猫”。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f749203cd740"
      },
      "source": [
        "### 标准化数据\n",
        "\n",
        "我们的原始图像有各种大小。另外，每个像素由 0 到 255 之间的 3 个整数值（RGB 色阶值）组成。这不太适合馈送神经网络。我们需要做下面两件事：\n",
        "\n",
        "- 标准化为固定图像大小。我们选择 150x150。\n",
        "- 在 -1 至 1 之间归一化像素值。我们将使用 `Normalization` 层作为模型本身的一部分来进行此操作。\n",
        "\n",
        "一般而言，与采用已预处理数据的模型相反，开发以原始数据作为输入的模型是一种良好的做法。原因在于，如果模型需要预处理的数据，则每次导出模型以在其他地方（在网络浏览器、移动应用中）使用时，都需要重新实现完全相同的预处理流水线。这很快就会变得非常棘手。因此，在命中模型之前，我们应当尽可能少地进行预处理。\n",
        "\n",
        "在这里，我们将在数据流水线中进行图像大小调整（因为深度神经网络只能处理连续的数据批次），并在创建模型时将其作为模型的一部分进行输入值缩放。\n",
        "\n",
        "我们将图像的大小调整为 150x150："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "b3678f38e087"
      },
      "outputs": [],
      "source": [
        "size = (150, 150)\n",
        "\n",
        "train_ds = train_ds.map(lambda x, y: (tf.image.resize(x, size), y))\n",
        "validation_ds = validation_ds.map(lambda x, y: (tf.image.resize(x, size), y))\n",
        "test_ds = test_ds.map(lambda x, y: (tf.image.resize(x, size), y))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "708bf9792a35"
      },
      "source": [
        "此外，我们对数据进行批处理并使用缓存和预提取来优化加载速度。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "53ef9e6092e3"
      },
      "outputs": [],
      "source": [
        "batch_size = 32\n",
        "\n",
        "train_ds = train_ds.cache().batch(batch_size).prefetch(buffer_size=10)\n",
        "validation_ds = validation_ds.cache().batch(batch_size).prefetch(buffer_size=10)\n",
        "test_ds = test_ds.cache().batch(batch_size).prefetch(buffer_size=10)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "b60f852c462f"
      },
      "source": [
        "### 使用随机数据扩充\n",
        "\n",
        "当您没有较大的图像数据集时，通过将随机但现实的转换（例如随机水平翻转或小幅随机旋转）应用于训练图像来人为引入样本多样性是一种良好的做法。这有助于使模型暴露于训练数据的不同方面，同时减慢过拟合的速度。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "40b1e355b9c0"
      },
      "outputs": [],
      "source": [
        "from tensorflow import keras\n",
        "from tensorflow.keras import layers\n",
        "\n",
        "data_augmentation = keras.Sequential(\n",
        "    [\n",
        "        layers.experimental.preprocessing.RandomFlip(\"horizontal\"),\n",
        "        layers.experimental.preprocessing.RandomRotation(0.1),\n",
        "    ]\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6fa8ddeda36e"
      },
      "source": [
        "我们看一下经过各种随机转换后第一个批次的第一个图像是什么样："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9077f9fd022e"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "for images, labels in train_ds.take(1):\n",
        "    plt.figure(figsize=(10, 10))\n",
        "    first_image = images[0]\n",
        "    for i in range(9):\n",
        "        ax = plt.subplot(3, 3, i + 1)\n",
        "        augmented_image = data_augmentation(\n",
        "            tf.expand_dims(first_image, 0), training=True\n",
        "        )\n",
        "        plt.imshow(augmented_image[0].numpy().astype(\"int32\"))\n",
        "        plt.title(int(labels[0]))\n",
        "        plt.axis(\"off\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bc999e4672c3"
      },
      "source": [
        "## 构建模型\n",
        "\n",
        "现在，我们来构建一个遵循我们先前解释的蓝图的模型。\n",
        "\n",
        "注意：\n",
        "\n",
        "- 我们添加 `Normalization` 层以将输入值（最初在 `[0, 255]` 范围内）缩放到 `[-1, 1]` 范围。\n",
        "- 我们在分类层之前添加一个 `Dropout` 层，以进行正则化。\n",
        "- 我们确保在调用基础模型时传递 `training=False`，使其在推断模式下运行，这样，即使在我们解冻基础模型以进行微调后，batchnorm 统计信息也不会更新。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "07a2f9e9d817"
      },
      "outputs": [],
      "source": [
        "base_model = keras.applications.Xception(\n",
        "    weights=\"imagenet\",  # Load weights pre-trained on ImageNet.\n",
        "    input_shape=(150, 150, 3),\n",
        "    include_top=False,\n",
        ")  # Do not include the ImageNet classifier at the top.\n",
        "\n",
        "# Freeze the base_model\n",
        "base_model.trainable = False\n",
        "\n",
        "# Create new model on top\n",
        "inputs = keras.Input(shape=(150, 150, 3))\n",
        "x = data_augmentation(inputs)  # Apply random data augmentation\n",
        "\n",
        "# Pre-trained Xception weights requires that input be normalized\n",
        "# from (0, 255) to a range (-1., +1.), the normalization layer\n",
        "# does the following, outputs = (inputs - mean) / sqrt(var)\n",
        "norm_layer = keras.layers.experimental.preprocessing.Normalization()\n",
        "mean = np.array([127.5] * 3)\n",
        "var = mean ** 2\n",
        "# Scale inputs to [-1, +1]\n",
        "x = norm_layer(x)\n",
        "norm_layer.set_weights([mean, var])\n",
        "\n",
        "# The base model contains batchnorm layers. We want to keep them in inference mode\n",
        "# when we unfreeze the base model for fine-tuning, so we make sure that the\n",
        "# base_model is running in inference mode here.\n",
        "x = base_model(x, training=False)\n",
        "x = keras.layers.GlobalAveragePooling2D()(x)\n",
        "x = keras.layers.Dropout(0.2)(x)  # Regularize with dropout\n",
        "outputs = keras.layers.Dense(1)(x)\n",
        "model = keras.Model(inputs, outputs)\n",
        "\n",
        "model.summary()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2e8237de81e8"
      },
      "source": [
        "## 训练顶层"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9137b8daedad"
      },
      "outputs": [],
      "source": [
        "model.compile(\n",
        "    optimizer=keras.optimizers.Adam(),\n",
        "    loss=keras.losses.BinaryCrossentropy(from_logits=True),\n",
        "    metrics=[keras.metrics.BinaryAccuracy()],\n",
        ")\n",
        "\n",
        "epochs = 20\n",
        "model.fit(train_ds, epochs=epochs, validation_data=validation_ds)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aa51d4562fa7"
      },
      "source": [
        "## 对整个模型进行一轮微调\n",
        "\n",
        "最后，我们解冻基础模型，并以较低的学习率端到端地训练整个模型。\n",
        "\n",
        "重要的是，尽管基础模型变得可训练，但在构建模型过程中，由于我们在调用该模型时传递了 `training=False`，因此它仍在推断模式下运行。这意味着内部的批次归一化层不会更新其批次统计信息。如果它们更新了这些统计信息，则会破坏该模型到目前为止所学习的表示。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3cc299505b72"
      },
      "outputs": [],
      "source": [
        "# Unfreeze the base_model. Note that it keeps running in inference mode\n",
        "# since we passed `training=False` when calling it. This means that\n",
        "# the batchnorm layers will not update their batch statistics.\n",
        "# This prevents the batchnorm layers from undoing all the training\n",
        "# we've done so far.\n",
        "base_model.trainable = True\n",
        "model.summary()\n",
        "\n",
        "model.compile(\n",
        "    optimizer=keras.optimizers.Adam(1e-5),  # Low learning rate\n",
        "    loss=keras.losses.BinaryCrossentropy(from_logits=True),\n",
        "    metrics=[keras.metrics.BinaryAccuracy()],\n",
        ")\n",
        "\n",
        "epochs = 10\n",
        "model.fit(train_ds, epochs=epochs, validation_data=validation_ds)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "afa73d989302"
      },
      "source": [
        "经过 10 个周期后，微调在这里为我们提供了出色的改进。"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "transfer_learning.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
